import java.util.Arrays;
import java.util.Scanner;

/**
 * "Delete and earn": taking every copy of a value {@code v} earns {@code v} points per copy,
 * but forbids taking any copy of {@code v - 1} or {@code v + 1}. Reads an array from stdin
 * (a count followed by that many ints) and prints the best achievable score.
 */
public class test2 {
    /**
     * Reads {@code n} followed by {@code n} integers from standard input and prints
     * {@link #getResult(int[])} for them.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int[] nums = new int[n];
            for (int i = 0; i < n; i++) {
                nums[i] = in.nextInt();
            }
            System.out.println(getResult(nums));
        }
    }

    /**
     * Returns the maximum number of points obtainable from {@code nums}, where choosing a
     * value {@code v} earns {@code v} times its number of occurrences and excludes the
     * adjacent values {@code v - 1} and {@code v + 1}. Choosing nothing is allowed, so the
     * result is never negative.
     *
     * <p>Runs in O(n log n) time and O(n) extra space regardless of how far apart the values
     * are (the previous version allocated an array spanning {@code max - min + 1}).
     *
     * @param nums the input values; may be {@code null} or empty (result 0); not modified
     * @return the maximum score
     * @throws ArithmeticException if the maximum score does not fit in an {@code int}
     */
    public static int getResult(int[] nums) {
        if (nums == null || nums.length == 0) {
            return 0;
        }
        int[] sorted = nums.clone(); // do not reorder the caller's array
        Arrays.sort(sorted);

        // House-robber DP over the distinct values in ascending order.
        long take = 0; // best score given the previous distinct value was taken
        long skip = 0; // best score given the previous distinct value was skipped
        // Sentinel: Long.MIN_VALUE + 1 can never equal an int, so the first value is never
        // treated as adjacent to a predecessor.
        long prev = Long.MIN_VALUE;

        int i = 0;
        while (i < sorted.length) {
            int value = sorted[i];
            long total = 0; // long: value * count can exceed the int range
            while (i < sorted.length && sorted[i] == value) {
                total += value;
                i++;
            }
            long best = Math.max(take, skip);
            // Adjacent to the previous value: taking this one requires having skipped that one.
            take = (value == prev + 1 ? skip : best) + total;
            skip = best;
            prev = value;
        }
        return Math.toIntExact(Math.max(take, skip));
    }
}
